[Rust] Amazon Transcribe Streamingを使ってmp3をテキスト化する

[Rust] Amazon Transcribe Streamingを使ってmp3をテキスト化する

Clock Icon2024.10.24

Introduction

この前は動画内でしゃべった言葉をテキスト検索するという記事を書きましたが、今度はmp3ファイルから音声データをstreamingでテキスト化するデモを
Rustで実装しました。

Environment

  • Rust: 1.81.0

AWSアカウントはセットアップ済みで、Transcribeが使用可能な状態とします。

Setup

まずはcargoでプロジェクトを作成します。

% cargo new transcribe-demo
% cd transcribe-demo

必要な依存関係をCargo.tomlに追加します。

[dependencies]
aws-config = "0.55.3"
aws-sdk-transcribestreaming = "0.55.3"
tokio = { version = "1.0", features = ["full"] }
symphonia = { version = "0.5.3", features = ["mp3"] }
async-stream = "0.3.5"
futures = "0.3.28"
clap = { version = "4.3.21", features = ["derive"] }
tracing-subscriber = "0.3.17"

stream処理をするので、aws-sdk-transcribestreamingを使います。

Try

このデモの主な機能は以下の通り。

  1. mp3ファイルをPCMデータへ変換
  2. ストリーミング形式でのAWS Transcribeへの送信
  3. リアルタイムでの文字起こし結果の取得

2024年10月現在、Transcribeは、mp3データをstreamingで処理できません。(一括処理ならOK)
そのため、最初にmp3からPCMデータへ変換してから処理します。

※コード全文は本稿の最後に記載しています

mp3ファイルをPCMデータへ変換

stream_mp3_to_mono_pcm関数では、Symphoniaを使用してmp3を
PCMデータに変換します。

fn stream_mp3_to_mono_pcm(
    input_file: &str,
) -> (
    impl Stream<Item = Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync,
    u32,
) {
    // mp3デコード処理
    // ステレオ->モノラル変換
    // チャンク分割処理
    // etc...
}

この関数では、mp3ファイルをloadして指定したサイズのchunkに分割してPCMのストリームとして出力します。

AWS Transcribeと連携

次に、変換したPCMストリームをAWS Transcribeに順々に送信します。
そして、順次文字起こし結果を取得します。

let mut output = client
    .start_stream_transcription()
    .language_code(LanguageCode::JaJp)
    .media_sample_rate_hertz(sample_rate as i32)
    .media_encoding(MediaEncoding::Pcm)
    .audio_stream(input_stream.into())
    .send()
    .await?;

文字起こし結果の処理

文字起こし結果は、部分的な結果と確定結果の2種類が返されます:

while let Some(event) = output.transcript_result_stream.recv().await? {
    match event {
        TranscriptResultStream::TranscriptEvent(transcript_event) => {
            // 部分的な結果と確定結果の処理
            // ...
        }
        otherwise => panic!("unexpected event type: {:?}", otherwise),
    }
}

実行

下記コマンドで実行できます。

cargo run --bin main -- -a <mp3ファイルのパス>

実行すると、リアルタイムで文字起こしが行われ、結果が表示されます。

今回はここにある、
「G-01 : CM原稿(せっけん)」をサンプルで読み上げしてみます。
実行すると、下記のように順番に少しずつ読み上げされます。

% cargo run --bin main -- -a example.mp3

・・・

Tracing initialization took: 1.699291ms
Command line parsing took: 1.156125ms
Region provider setup took: 2.771ms
AWS client setup took: 121.126666ms
MP3 to PCM stream setup took: 659.583µs
PCM stream creation took: 665.083µs

・・・

Event processing took: 7.042µs
Event processing took: 6.958µs
Transcribed: 無添加のしゃぼん玉石鹸ならもう安心。
Event processing took: 48µs
・・・
Event processing took: 11.25µs
Event processing took: 8.959µs
Event processing took: 11.167µs
Transcribed: 天然の保湿成分が含まれるため、肌に潤いを与え、健やかに保ちます。
Event processing took: 60.875µs
Event processing took: 1.458µs
・・・
Event processing took: 14.042µs
Event processing took: 9.125µs
Transcribed: お肌のことでお悩みの方はぜひ一度、無添加しゃぼん玉石鹸をお試しください。
Event processing took: 33.042µs
Event processing took: 625ns

・・・

#最後にまとめて表示
Fully transcribed message:

無添加のしゃぼん玉石鹸ならもう安心。
天然の保湿成分が含まれるため、肌に潤いを与え、健やかに保ちます。
お肌のことでお悩みの方はぜひ一度、無添加しゃぼん玉石鹸をお試しください。
お求めは、ゼロ一二ゼロ、ゼロゼロ五五九五まで。

Total execution time: 16.21453325s

Summary

AWS Transcribe streamingでmp3から文字起こしをしてみました。
なお、mp3でも一括変換であれば今回のようなpcm変換は必要ないのでもっと簡単にできます。

コード全文

use async_stream::stream;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_transcribestreaming::primitives::Blob;
use aws_sdk_transcribestreaming::types::{
    AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream,
};
use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION, Client};
use clap::Parser;
use futures::Stream;
use std::fs::File;
use std::io::Write;
use std::time::Instant;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;

fn create_docx(text: &str, output_path: &str) -> Result<(), DocxError> {
    let path = std::path::Path::new(output_path);
    let file = File::create(&path).unwrap();
    let mut docx = Docx::new();

    // テキストを段落に分割
    for paragraph in text.split("\n") {
        docx = docx.add_paragraph(Paragraph::new().add_run(Run::new().add_text(paragraph)));
    }

    Ok(())
}

#[derive(Debug, Parser)]
struct Opt {
    #[structopt(short, long)]
    region: Option<String>,

    /// 処理する音声ファイルパス
    #[structopt(short, long)]
    audio_file: String,

    #[structopt(short, long)]
    verbose: bool,
}

// PCMデータのchunk size(byte単位)
const CHUNK_SIZE: usize = 8192;

/// mp3からモノラルPCMデータをストリーミング
fn stream_mp3_to_mono_pcm(
    input_file: &str,
) -> (
    impl Stream<Item = Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync,
    u32,
) {
    let start = Instant::now();

    let file = File::open(input_file).unwrap();
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    hint.with_extension("mp3");

    let format_opts: FormatOptions = Default::default();
    let metadata_opts: MetadataOptions = Default::default();
    let decoder_opts: DecoderOptions = Default::default();

    let probed = symphonia::default::get_probe()
        .format(&hint, mss, &format_opts, &metadata_opts)
        .unwrap();

    let mut format = probed.format;
    let track = format.default_track().unwrap();
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &decoder_opts)
        .unwrap();

    let mut sample_buf: Option<SampleBuffer<i16>> = None;
    let sample_rate = track.codec_params.sample_rate.unwrap();

    // PCMストリームを作成
    let stream = async_stream::try_stream! {
        while let Ok(packet) = format.next_packet() {
            let decoded = decoder.decode(&packet)?;
            let spec = *decoded.spec();

            if sample_buf.is_none() {
                sample_buf = Some(SampleBuffer::new(decoded.capacity() as u64, spec));
            }

            let sample_buf = sample_buf.as_mut().unwrap();
            sample_buf.copy_interleaved_ref(decoded);
            let num_channels = spec.channels.count();

            let mut mono_samples = Vec::new();
            for i in 0..sample_buf.samples().len() / num_channels {
                let mono_sample = if num_channels > 1 {
                    // ステレオの場合、左右チャンネルの平均を取る
                    let left = sample_buf.samples()[i * 2] as i32;
                    let right = sample_buf.samples()[i * 2 + 1] as i32;
                    ((left + right) / 2) as i16
                } else {
                    // モノラルの場合はそのまま
                    sample_buf.samples()[i]
                };
                mono_samples.extend_from_slice(&mono_sample.to_le_bytes());
            }

            // モノラルPCMデータをチャンクに分割してyield
            for chunk in mono_samples.chunks(CHUNK_SIZE) {
                yield chunk.to_vec();
            }
        }
    };

    let elapsed = start.elapsed();
    println!("MP3 to PCM stream setup took: {:?}", elapsed);
    (stream, sample_rate)
}

/// 関数の実行時間を計測
macro_rules! time_it {
    ($desc:expr, $func:expr) => {{
        let start = Instant::now();
        let result = $func;
        println!("{} took: {:?}", $desc, start.elapsed());
        result
    }};
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let total_start = Instant::now();

    time_it!("Tracing initialization", {
        tracing_subscriber::fmt::init();
    });

    let Opt {
        region,
        audio_file,
        verbose,
    } = time_it!("Command line parsing", Opt::parse());

    let region_provider = time_it!("Region provider setup", {
        RegionProviderChain::first_try(region.map(Region::new))
            .or_default_provider()
            .or_else(Region::new("us-west-2"))
    });

    if verbose {
        println!("Transcribe client version: {}", PKG_VERSION);
        println!(
            "Region:                    {}",
            region_provider.region().await.unwrap().as_ref()
        );
        println!("Audio filename:            {}", &audio_file);
        println!();
    }

    let client = time_it!("AWS client setup", {
        let shared_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
        Client::new(&shared_config)
    });

    // MP3からPCMストリームを作成
    let (pcm_stream, sample_rate) = time_it!("PCM stream creation", {
        stream_mp3_to_mono_pcm(&audio_file)
    });

    // トランスクリプション用の入力ストリームを作成
    let input_stream = time_it!("Input stream creation", {
        stream! {
            for await chunk in pcm_stream {
                match chunk {
                    Ok(data) => {
                        // PCMデータをAudioEventに変換
                        yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(data)).build()));
                    }
                    Err(e) => {
                        eprintln!("Error streaming PCM data: {:?}", e);
                        break;
                    }
                }
            }
        }
    });

    // トランスクリプションの開始
    let mut output = time_it!("Transcription setup", {
        client
            .start_stream_transcription()
            .language_code(LanguageCode::JaJp)
            .media_sample_rate_hertz(sample_rate as i32)
            .media_encoding(MediaEncoding::Pcm)
            .audio_stream(input_stream.into())
            .send()
            .await?
    });

    // トランスクリプション結果の処理
    let mut full_message = String::new();
    time_it!("Total transcription processing", {
        while let Some(event) = output.transcript_result_stream.recv().await? {
            time_it!("Event processing", {
                match event {
                    TranscriptResultStream::TranscriptEvent(transcript_event) => {
                        let transcript = transcript_event.transcript.unwrap();
                        for result in transcript.results.unwrap_or_default() {
                            if result.is_partial {
                                if verbose {
                                    println!("Partial: {:?}", result);
                                }
                            } else {
                                let first_alternative = &result.alternatives.as_ref().unwrap()[0];
                                let transcribed_text =
                                    first_alternative.transcript.as_ref().unwrap();
                                full_message += transcribed_text;
                                full_message.push('\n');
                                println!("Transcribed: {}", transcribed_text);
                            }
                        }
                    }
                    otherwise => panic!("received unexpected event type: {:?}", otherwise),
                }
            });
        }
    });

    println!("\nFully transcribed message:\n\n{}", full_message);
    println!("Total execution time: {:?}", total_start.elapsed());
    Ok(())
}

Appendix : Javascript版

ついでにnodejs版。
rustの前に試したやつです。

const fs = require('fs');
const ffmpeg = require('fluent-ffmpeg');
const { PassThrough, Transform } = require('stream');
const {
  TranscribeStreamingClient,
  StartStreamTranscriptionCommand
} = require('@aws-sdk/client-transcribe-streaming');

// AWS regin
const REGION = 'ap-northeast-1';

const client = new TranscribeStreamingClient({ region: REGION });

//16kb chunkにsplit
class ChunkSplitter extends Transform {
  constructor(options) {
    super(options);
    this.maxChunkSize = 16 * 1024; // 16kb
    this.buffer = Buffer.alloc(0);
  }

  _transform(chunk, encoding, callback) {
    this.buffer = Buffer.concat([this.buffer, chunk]);

    while (this.buffer.length >= this.maxChunkSize) {
      const chunkToPush = this.buffer.slice(0, this.maxChunkSize);
      //console.log(`Pushing chunk of size: ${chunkToPush.length} bytes`);
      this.push(chunkToPush);
      this.buffer = this.buffer.slice(this.maxChunkSize);
    }

    callback();
  }

  _flush(callback) {
    if (this.buffer.length > 0) {
      //console.log(`Pushing final chunk of size: ${this.buffer.length} bytes`);
      this.push(this.buffer);
      this.buffer = Buffer.alloc(0);
    }
    callback();
  }
}

// オーディオストリームを生成
async function* generateAudioStream(filePath) {
  const passThrough = new PassThrough();

  // FFmpeg を使用して MP3 から PCM に変換
  ffmpeg(filePath)
    .inputFormat('mp3')
    .audioFrequency(16000)    // サンプリングレート
    .audioChannels(1)         // チャンネル数
    .format('s16le')          // PCM形式
    .on('error', (err) => {
      console.error('FFmpeg error:', err);
      passThrough.destroy(err);
    })
    .on('end', () => {
      console.log('FFmpeg stream ended');
      passThrough.end();
    })
    .pipe(passThrough);

  // split chunk
  const splitter = new ChunkSplitter();
  passThrough.pipe(splitter);

  for await (const chunk of splitter) {
    yield { AudioEvent: { AudioChunk: chunk } };
  }
}

async function transcribeMP3(filePath) {
  const command = new StartStreamTranscriptionCommand({
    LanguageCode: 'ja-JP',            // 音声の言語コード
    MediaSampleRateHertz: 16000,      // PCMのサンプリングレート
    MediaEncoding: 'pcm',             // オーディオのエンコーディング形式
    AudioStream: generateAudioStream(filePath), 
  });

  try {
    // Transcribe Streaming
    const response = await client.send(command);

    for await (const event of response.TranscriptResultStream) {
      if (event.TranscriptEvent) {
        const results = event.TranscriptEvent.Transcript.Results;
        results.forEach(result => {
          if (!result.IsPartial) {
            console.log('Transcription:', result.Alternatives[0].Transcript);
          }
        });
      }
    }
  } catch (error) {
    console.error('Error transcribing audio:', error);
  }
}

// 使用例
transcribeMP3('<mp3ファイルパス>');

References

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.